{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import gym\n",
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "env = gym.make('Pendulum-v1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class CriticNet(torch.nn.Module):\n",
    "    def __init__(self, state_dim=env.observation_space.shape[0], output_dims=1):\n",
    "        super(CriticNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(state_dim, 64)\n",
    "        self.relu1 = torch.nn.ReLU()\n",
    "        self.fc2 = torch.nn.Linear(64, 64)\n",
    "        self.relu2 = torch.nn.ReLU()\n",
    "        # self.fc3 = torch.nn.Linear(64, 64)\n",
    "        # self.relu3 = torch.nn.ReLU()\n",
    "        self.fc4 = torch.nn.Linear(64, output_dims)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu1(self.fc1(x))\n",
    "        x = self.relu2(self.fc2(x))\n",
    "        # x = self.relu3(self.fc3(x))\n",
    "        x = self.fc4(x)\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "55f44b374b3b6bbb"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class ActorNet(torch.nn.Module):\n",
    "    def __init__(self, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0]):\n",
    "        super(ActorNet, self).__init__()\n",
    "        self.fc1 = torch.nn.Linear(state_dim, 64)\n",
    "        self.relu1 = torch.nn.ReLU()\n",
    "\n",
    "        self.fc2 = torch.nn.Linear(64, 64)\n",
    "        self.relu2 = torch.nn.ReLU()\n",
    "\n",
    "        # self.fc3 = torch.nn.Linear(64, 64)\n",
    "        # self.relu3 = torch.nn.ReLU()\n",
    "\n",
    "        self.fc_mu = torch.nn.Linear(64, action_dim)\n",
    "        self.fc_sigma = torch.nn.Linear(64, action_dim)\n",
    "        self.tanh = torch.nn.Tanh()\n",
    "        self.softplus = torch.nn.Softplus()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu1(self.fc1(x))\n",
    "        x = self.relu2(self.fc2(x))\n",
    "        # x = self.relu3(self.fc3(x))\n",
    "        mu = self.tanh(self.fc_mu(x))\n",
    "        sigma = self.softplus(self.fc_sigma(x))\n",
    "        # return torch.square(mu), sigma\n",
    "        return mu * 2, sigma\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7c5e67c8d79b4fbf"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class ProximalPolicyOptimization:\n",
    "    def __init__(self, state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0]):\n",
    "        self.state_dim = state_dim\n",
    "        self.actor_net = ActorNet(state_dim, action_dim)\n",
    "        self.old_actor_net = ActorNet(state_dim, action_dim)\n",
    "        self.critic_net = CriticNet(state_dim, 1)\n",
    "        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=1e-4)\n",
    "        self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(), lr=2e-4)\n",
    "        self.epsilon = 0.2\n",
    "        self.gamma = 0.9\n",
    "\n",
    "    def choose_action(self, state):\n",
    "        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)\n",
    "        with torch.no_grad():\n",
    "            mu, sigma = self.actor_net(state)\n",
    "        action = torch.clip(torch.distributions.Normal(mu, sigma).sample(), -2, 2)\n",
    "        return [action.item()]\n",
    "\n",
    "    def compute_discounted_reward(self, rewards, gamma, new_state):\n",
    "        v_ = self.critic_net(torch.tensor(new_state, dtype=torch.float32).unsqueeze(0))\n",
    "        v_ = v_.item()\n",
    "        discounted_reward = []\n",
    "        for r in rewards[::-1]:\n",
    "            v_ = r + gamma * v_\n",
    "            discounted_reward.append(v_)\n",
    "        discounted_reward.reverse()\n",
    "        return discounted_reward\n",
    "\n",
    "    def train(self, state, action, reward):\n",
    "        self.old_actor_net.load_state_dict(self.actor_net.state_dict())\n",
    "        action = torch.tensor(action, dtype=torch.float32)\n",
    "\n",
    "        for _ in range(10):\n",
    "            # 计算advantage，也就是td-error\n",
    "            reward1 = torch.tensor(reward.copy(), dtype=torch.float32)\n",
    "            state1 = torch.tensor(state.copy(), dtype=torch.float32)\n",
    "            adv = reward1 - self.critic_net(state1)\n",
    "            pi = torch.distributions.Normal(*self.actor_net(state1))\n",
    "            oldpi = torch.distributions.Normal(*self.old_actor_net(state1))\n",
    "            ratio = torch.exp(pi.log_prob(action) - oldpi.log_prob(action) + 1e-8)\n",
    "            loss = -torch.mean(\n",
    "                torch.min(ratio * adv,\n",
    "                          torch.clamp(ratio, 1. - 0.2, 1. + 0.2) * adv)\n",
    "            )\n",
    "\n",
    "            self.actor_optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.actor_optimizer.step()\n",
    "\n",
    "        for _ in range(10):\n",
    "            reward2 = torch.tensor(reward.copy(), dtype=torch.float32)\n",
    "            state2 = torch.tensor(state.copy(), dtype=torch.float32)\n",
    "            adv = reward2 - self.critic_net(state2)\n",
    "            loss = torch.mean(torch.square(adv))\n",
    "            self.critic_optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.critic_optimizer.step()\n",
    "\n",
    "    def model_save(self, path):\n",
    "        torch.save({\n",
    "            'actor_net_model_state_dict': self.actor_net.state_dict(),\n",
    "            'old_actor_net_model_state_dict': self.old_actor_net.state_dict(),\n",
    "            'critic_net_model_state_dict': self.critic_net.state_dict(),\n",
    "            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),\n",
    "            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),\n",
    "        }, path)\n",
    "\n",
    "    def model_load(self, path):\n",
    "        checkpoint = torch.load(path)\n",
    "        self.actor_net.load_state_dict(checkpoint['actor_net_model_state_dict'])\n",
    "        self.old_actor_net.load_state_dict(checkpoint['old_actor_net_model_state_dict'])\n",
    "        self.critic_net.load_state_dict(checkpoint['critic_net_model_state_dict'])\n",
    "        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])\n",
    "        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b3492a0ddb8c5d59"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "log_dir = './runs'\n",
    "if os.path.exists(log_dir):\n",
    "    try:\n",
    "        shutil.rmtree(log_dir)\n",
    "        print(f'文件夹 {log_dir} 已成功删除。')\n",
    "    except OSError as error:\n",
    "        print(f'删除文件夹 {log_dir} 失败: {error}')\n",
    "else:\n",
    "    os.makedirs(log_dir)\n",
    "    print(f'文件夹 {log_dir} 不存在，已创建文件夹 {log_dir}。')"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "35df1b21fd9ed87d",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "summary_writer = SummaryWriter(log_dir=log_dir)\n",
    "batch = 32\n",
    "ppo = ProximalPolicyOptimization()\n",
    "episode = 500\n",
    "steps = 320\n",
    "all_reward = []\n",
    "for epoch in range(episode):\n",
    "    start_time = time.time()\n",
    "    state, _ = env.reset()\n",
    "    step = 0\n",
    "    buffer = []\n",
    "    episode_rewards = 0\n",
    "    buffer_state, buffer_action, buffer_reward = [], [], []\n",
    "    while step < steps:\n",
    "        # choose action\n",
    "        action = ppo.choose_action(state)\n",
    "        new_state, reward, done, _, _ = env.step(action)\n",
    "        buffer_state.append(state)\n",
    "        buffer_action.append(action)\n",
    "        buffer_reward.append((reward + 6) / 10)\n",
    "        state = new_state\n",
    "        episode_rewards += reward\n",
    "        step += 1\n",
    "        if (step + 1) % batch == 0 or step == steps - 1:\n",
    "            discounted_reward = ppo.compute_discounted_reward(buffer_reward, ppo.gamma, new_state)\n",
    "            # buffer_state = torch.FloatTensor(buffer_state)\n",
    "            # buffer_action = torch.tensor(buffer_action).view(-1, 1)\n",
    "            # discounted_reward = torch.tensor(discounted_reward).view(-1, 1)\n",
    "            buffer_state, buffer_action, discounted_reward = np.vstack(buffer_state), np.vstack(\n",
    "                buffer_action), np.vstack(discounted_reward)\n",
    "            ppo.train(buffer_state, buffer_action, discounted_reward)\n",
    "            buffer_state, buffer_action, buffer_reward = [], [], []\n",
    "    all_reward.append(episode_rewards)\n",
    "    summary_writer.add_scalar('episode_rewards', episode_rewards, epoch)\n",
    "    print(\"Epoch/Episode: {}/{},reward: {}\".format(epoch + 1, episode, episode_rewards))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "6e53e7cbb3f6672f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "env = gym.make(\"Pendulum-v1\", render_mode='human')\n",
    "state, _ = env.reset()\n",
    "step = 0\n",
    "episode_rewards = 0\n",
    "while True:\n",
    "    state = torch.tensor(state)\n",
    "    a = ppo.choose_action(state)\n",
    "    new_state, reward, done, _, _ = env.step(a)\n",
    "    step += 1\n",
    "    state = new_state"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "f03555de6813f9fa"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "env.close()"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ae94b2deb96c50d7"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "ppo.model_save('PPO_Pendulum_V1.pth')"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3ab2211e00baad39"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
